{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatically generating object masks with SAM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook walks through how to automatically segment objects in an image. It is adapted from the [original SAM GitHub repo](https://github.com/facebookresearch/segment-anything/).\n",
    "\n",
    "Because SAM processes prompts efficiently, masks for an entire image can be generated by sampling a large number of prompts over the image. This method was used to generate the SA-1B dataset.\n",
    "\n",
    "The class `SamAutomaticMaskGenerator` implements this. It samples single-point input prompts in a grid over the image, from each of which SAM then predicts multiple masks. The masks are filtered for quality and deduplicated using non-max suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes."
   ]
  },
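  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the grid sampling described above, the sketch below builds an evenly spaced grid of normalized point coordinates with NumPy. The helper name `point_grid`, the half-cell offset from the image border, and the 640x480 image size are assumptions made for this example; the actual sampling inside `SamAutomaticMaskGenerator` may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def point_grid(n_per_side):\n",
    "    # Hypothetical helper: returns an (n_per_side**2, 2) array of (x, y) points in [0, 1].\n",
    "    # Assumption: each point sits at the center of one cell of an n x n grid.\n",
    "    offset = 1.0 / (2 * n_per_side)\n",
    "    ticks = np.linspace(offset, 1.0 - offset, n_per_side)\n",
    "    xs, ys = np.meshgrid(ticks, ticks)\n",
    "    return np.stack([xs.ravel(), ys.ravel()], axis=-1)\n",
    "\n",
    "grid = point_grid(4)\n",
    "# Scale to pixel coordinates for a hypothetical image 640 pixels wide and 480 pixels high.\n",
    "pixel_points = grid * np.array([640, 480])\n",
    "print(grid.shape)"
   ]
  },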
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_anns(anns):\n",
    "    # Overlay each mask on the current axes in a random, semi-transparent color.\n",
    "    if len(anns) == 0:\n",
    "        return\n",
    "    # Draw the largest masks first so that smaller ones stay visible on top.\n",
    "    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)\n",
    "    ax = plt.gca()\n",
    "    ax.set_autoscale_on(False)\n",
    "\n",
    "    # Start from a fully transparent RGBA canvas the same size as the masks.\n",
    "    height, width = sorted_anns[0]['segmentation'].shape[:2]\n",
    "    img = np.ones((height, width, 4))\n",
    "    img[:,:,3] = 0\n",
    "    for ann in sorted_anns:\n",
    "        m = ann['segmentation']\n",
    "        # Random RGB color with alpha 0.35.\n",
    "        color_mask = np.concatenate([np.random.random(3), [0.35]])\n",
    "        img[m] = color_mask\n",
    "    ax.imshow(img)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example image"
   ]
  },
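  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = cv2.imread('images/dog.jpg')\n",
    "# cv2.imread returns None, rather than raising, when the file cannot be read.\n",
    "assert image is not None, 'Could not read images/dog.jpg'\n",
    "image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB"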
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic mask generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "from segment_anything import SamAutomaticMaskGenerator\n",
    "from segment_anything.sam import load\n",
    "\n",
    "sam_checkpoint = \"../sam-vit-base\"\n",
    "sam = load(sam_checkpoint)\n",
    "\n",
    "mask_generator = SamAutomaticMaskGenerator(sam)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To generate masks, run `generate` on an image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks = mask_generator.generate(image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Mask generation returns a list of masks. Each item is a dictionary with the following keys:\n",
    "* `segmentation` : the mask\n",
    "* `area` : the area of the mask in pixels\n",
    "* `bbox` : the bounding box of the mask in XYWH format\n",
    "* `predicted_iou` : the model's own prediction for the quality of the mask\n",
    "* `point_coords` : the sampled input point that generated this mask\n",
    "* `stability_score` : an additional measure of mask quality\n",
    "* `crop_box` : the crop of the image used to generate this mask in XYWH format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(masks))\n",
    "print(masks[0].keys())"
   ]
  },
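  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The records described above can be post-processed with plain Python. The sketch below is an illustration only: it uses two hand-made records in place of real generator output, and the helper names `xywh_to_xyxy` and `filter_masks`, as well as the threshold values, are hypothetical choices for this example. On real output the same helpers could be applied as, for instance, `filter_masks(masks, min_area=500)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def xywh_to_xyxy(bbox):\n",
    "    # Convert [x, y, width, height] to [x_min, y_min, x_max, y_max].\n",
    "    x, y, w, h = bbox\n",
    "    return [x, y, x + w, y + h]\n",
    "\n",
    "def filter_masks(anns, min_area=0, min_iou=0.0):\n",
    "    # Keep records that meet both the area and the predicted-IoU thresholds.\n",
    "    return [a for a in anns if a['area'] >= min_area and a['predicted_iou'] >= min_iou]\n",
    "\n",
    "# Hand-made stand-ins for the dictionaries returned by `generate`.\n",
    "demo = [\n",
    "    {'segmentation': np.ones((4, 4), dtype=bool), 'area': 16, 'bbox': [0, 0, 4, 4], 'predicted_iou': 0.95},\n",
    "    {'segmentation': np.zeros((4, 4), dtype=bool), 'area': 0, 'bbox': [0, 0, 0, 0], 'predicted_iou': 0.50},\n",
    "]\n",
    "kept = filter_masks(demo, min_area=10, min_iou=0.9)\n",
    "print(len(kept), xywh_to_xyxy(kept[0]['bbox']))"
   ]
  },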
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show all the masks overlaid on the image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "show_anns(masks)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automatic mask generation options"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Several tunable parameters in automatic mask generation control how densely points are sampled and the thresholds for removing low-quality or duplicate masks. Generation can also be run automatically on crops of the image to get better results for smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mask_generator_2 = SamAutomaticMaskGenerator(\n",
    "    model=sam,\n",
    "    points_per_side=32,\n",
    "    pred_iou_thresh=0.86,\n",
    "    stability_score_thresh=0.92,\n",
    "    crop_n_layers=1,\n",
    "    crop_n_points_downscale_factor=2,\n",
    "    min_mask_region_area=100,  # Post-processing requires OpenCV\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "masks2 = mask_generator_2.generate(image)\n",
    "len(masks2)"
   ]
  },
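  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One rough way to compare two configurations is the fraction of image pixels covered by at least one mask. The helper `coverage` below is a hypothetical name introduced for this example, and the two tiny masks stand in for real output; it assumes each `segmentation` is a 2D array that can be cast to boolean. On the notebook's data it could be called as `coverage(masks)` and `coverage(masks2)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def coverage(anns):\n",
    "    # Fraction of pixels that belong to at least one mask.\n",
    "    if len(anns) == 0:\n",
    "        return 0.0\n",
    "    union = np.zeros(np.asarray(anns[0]['segmentation']).shape, dtype=bool)\n",
    "    for ann in anns:\n",
    "        union |= np.asarray(ann['segmentation']).astype(bool)\n",
    "    return float(union.mean())\n",
    "\n",
    "# Two overlapping 4x4 stand-in masks: rows 0-1 and rows 1-2.\n",
    "a = np.zeros((4, 4), dtype=bool)\n",
    "a[:2, :] = True\n",
    "b = np.zeros((4, 4), dtype=bool)\n",
    "b[1:3, :] = True\n",
    "demo_cov = coverage([{'segmentation': a}, {'segmentation': b}])\n",
    "print(demo_cov)"
   ]
  },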
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,20))\n",
    "plt.imshow(image)\n",
    "show_anns(masks2)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
